import numpy
import numpy.random
import re
def asm(*args):
  """Format one assembly instruction as ``mnemonic<pad>op1, op2, ...``.

  ``args[0]`` is the mnemonic and the remaining args are operand strings.
  The tab placed after the mnemonic is expanded to spaces (8-column tab
  stops, the ``str.expandtabs`` default) so operands line up in the output.
  """
  mnemonic = args[0]
  operands = ", ".join(args[1:])
  return "\t".join([mnemonic, operands]).expandtabs()
# Named inline-asm operands; gen_bitonic_macro declares them in the operand
# lists of the generated asm() statement.
VSWAP = "%[VSWAP]"
VCMP = "%[VCMP]"
AUX = "%[AUX]"


def VFLIP(arg):
  """Return the named operand for flip register *arg*, e.g. ``%[VFLIP3]``."""
  return "%%[VFLIP%s]" % arg


# Instruction mnemonics.
VSHFF = "vshff"
VSUBL = "vsubl"
VLOG3 = "vlog3r0"
VSELLT = "vsellt"
VCPYS = "vcpys"


def vsellt(cond, va, vb, vc):
  """Format ``vsellt cond, va, vb, vc``."""
  return asm(VSELLT, cond, va, vb, vc)


def vselgt(cond, va, vb, vc):
  """Mirror-image select: emit ``vsellt`` with the va/vb operands swapped."""
  return asm(VSELLT, cond, vb, va, vc)
def bitonic_vec(v, order):
  """Emit a 4-instruction in-register compare-exchange over the lanes of *v*.

  *order* holds one label per lane. Lanes ``a < b`` with
  ``order[a] ^ 1 == order[b]`` are partners. The partner map is packed two
  bits per lane (lane 0 lowest) into the ``vshff`` immediate. Each partnered
  lane also gets a selection bit: 0 for the lane with the smaller label and
  1 for its partner. These bits form the flip mask, with lane 0 as bit 0.
  When bit 3 of the mask is set, the mask is complemented and ``vselgt``
  replaces ``vsellt``, so the mask always names one of ``flip0``..``flip7``.

  NOTE(review): assumes *order* is a permutation of 0..3; an unpartnered
  lane silently keeps partner 0 and selection bit 0.

  Returns:
    A list of four instruction strings.
  """
  partner = [0, 0, 0, 0]
  sel_bits = [0, 0, 0, 0]
  lane_pairs = ((a, b) for a in range(4) for b in range(a + 1, 4))
  for a, b in lane_pairs:
    if order[b] != (order[a] ^ 1):
      continue
    partner[a], partner[b] = b, a
    a_is_lower = order[a] < order[b]
    sel_bits[a] = 0 if a_is_lower else 1
    sel_bits[b] = 1 if a_is_lower else 0

  flip_mask = sum(bit << lane for lane, bit in enumerate(sel_bits))
  select = vsellt
  if flip_mask & 8:
    # Keep the mask in 0..7 by complementing it and swapping the select sense.
    flip_mask ^= 15
    select = vselgt
  shff_mask = sum(pos << (2 * lane) for lane, pos in enumerate(partner))

  return [
    # v = [1, 2, 3, 4]
    asm(VSHFF, v, v, hex(shff_mask), VSWAP),
    # VSWAP = [2, 1, 4, 3]
    asm(VSUBL, v, VSWAP, VCMP),
    # VCMP = [-1, 1, -1, 1]
    asm(VLOG3, AUX, VFLIP("%x" % flip_mask), VCMP, VCMP),
    # VCMP = [-1, -1, 1, 1]
    select(VCMP, v, VSWAP, v),
    # v = [v[0], v[1], VSWAP[2], VSWAP[3]] = [1, 2, 4, 3]
  ]
def minmax_vec(v0, v1, reverse = False):
  """Emit a lane-wise compare-exchange between vector registers *v0* and *v1*.

  The sequence is:
    1. ``vsubl`` puts ``v0 - v1`` in VCMP.
    2. ``vcpys`` combines VCMP with the ``%[VFLIP*]`` wildcard operand. That
       wildcard is bound to a concrete flip register by gen_bitonic_macro.
       This presumably turns the difference's sign into a value the select
       can test; confirm against the ISA manual.
    3. Two selects route one operand into VSWAP and the other into *v0*.
    4. A final ``vcpys VSWAP, VSWAP, v1`` moves VSWAP into *v1*
       (presumably a plain copy).

  *reverse* uses ``vselgt``, whose data operands are swapped, instead of
  ``vsellt``. This inverts which register receives which element.

  Returns:
    A list of five instruction strings.
  """
  pick = vselgt if reverse else vsellt
  return [
    asm(VSUBL, v0, v1, VCMP),
    asm(VCPYS, VCMP, VFLIP("*"), VCMP),
    pick(VCMP, v1, v0, VSWAP),
    pick(VCMP, v0, v1, v0),
    asm(VCPYS, VSWAP, VSWAP, v1),
  ]
def minmax_vec_16(reverse = False):
  """Compare-exchange %[V0] with %[V2] and %[V1] with %[V3] (four registers).

  *reverse* is forwarded to minmax_vec for both pairs.
  """
  regs = ["%%[V%d]" % k for k in range(4)]
  insts = []
  for lo, hi in ((0, 2), (1, 3)):
    insts += minmax_vec(regs[lo], regs[hi], reverse)
  return insts
def bitonic_vec_init_8():
  """Emit three in-register bitonic stages on %[V0] and %[V1], interleaved.

  Both registers use the same first stage. The later V1 orders match the
  ones bitonic_dec_8 uses, and the V0 orders match bitonic_inc_8. This
  presumably leaves V0/V1 sorted in opposite directions, so together they
  form a bitonic sequence; confirm.
  """
  stages = [
    # (order for V0, order for V1)
    ([0, 1, 3, 2], [0, 1, 3, 2]),
    ([0, 2, 1, 3], [1, 3, 0, 2]),
    ([0, 1, 2, 3], [1, 0, 3, 2]),
  ]
  insts = []
  for order0, order1 in stages:
    insts += bitonic_vec("%[V0]", order0)
    insts += bitonic_vec("%[V1]", order1)
  return insts
def bitonic_inc_8():
  """Emit an 8-element merge over %[V0]/%[V1] in the non-reversed direction.

  The sequence is a cross-register minmax_vec, followed by two in-register
  bitonic_vec stages applied to V0 then V1.
  """
  regs = ("%[V0]", "%[V1]")
  insts = list(minmax_vec(*regs))
  for order in ([0, 2, 1, 3], [0, 1, 2, 3]):
    for reg in regs:
      insts.extend(bitonic_vec(reg, order))
  return insts
def bitonic_dec_8():
  """Emit the reversed counterpart of bitonic_inc_8 over %[V0]/%[V1].

  The sequence is a reversed cross-register minmax_vec, followed by two
  in-register bitonic_vec stages whose selection bits are the complement of
  those used by bitonic_inc_8.
  """
  regs = ("%[V0]", "%[V1]")
  insts = list(minmax_vec(*regs, reverse=True))
  for order in ([1, 3, 0, 2], [1, 0, 3, 2]):
    for reg in regs:
      insts.extend(bitonic_vec(reg, order))
  return insts
def gen_bitonic_macro(name, insts, narg = 2):
  """Wrap asm instruction strings in a C ``#define`` that runs them inline.

  The macro is ``name(v0, ..., v{narg-1})``. Operands are wired as follows:

  - Each ``%[Vk]`` is an output bound to ``vk``. It is tied to the same
    variable as input ``[Rk]`` by a matching constraint, so the registers
    are updated in place.
  - ``%[VSWAP]`` / ``%[VCMP]`` are early-clobber scratch registers held in
    ``t0`` / ``t1``.
  - ``%[AUX]`` and every ``%[VFLIPx]`` referenced in *insts* become inputs
    reading the C variables ``aux`` / ``flipx``. Those variables are
    declared by the AUX_FLIP_DEF macro that aux_flip_init emits.
  - The wildcard ``%[VFLIP*]`` is rewritten to the lowest declared flip
    operand.

  Args:
    name: macro name.
    insts: instruction strings, one per asm line.
    narg: number of in/out vector arguments.

  Returns:
    The macro text, with lines joined by backslash-newline continuations.
  """
  vflip_re = re.compile(r'%\[VFLIP([0-9a-f])\]')
  # Every concrete flip operand referenced anywhere (not only the first one
  # in each instruction) must be declared as an asm input.
  flips = set()
  for inst in insts:
    flips.update(vflip_re.findall(inst))
  if not flips:
    # The VFLIP* wildcard still needs a concrete register to bind to.
    flips.add('3')
  # Sort the flips so the generated text is identical on every run. Iteration
  # order of a set of str depends on PYTHONHASHSEED, which previously made the
  # wildcard binding and the input-operand order change between runs.
  flips = sorted(flips)
  wildcard = flips[0]

  lines = [
    "#define %s(%s) {" % (name, ", ".join("v%d" % k for k in range(narg))),
    "  int256 t0, t1;",
    "  asm(",
  ]
  for inst in insts:
    lines.append('    "%s\\n\\t"' % inst.replace("VFLIP*", "VFLIP" + wildcard))
  outputs = ", ".join('[V%d] "=r"(v%d)' % (k, k) for k in range(narg))
  lines.append('    : %s, [VSWAP] "=&r"(t0), [VCMP] "=&r"(t1)' % outputs)
  inputs = ['    : [AUX] "r"(aux)']
  inputs.extend('[VFLIP%s] "r"(flip%s)' % (x, x) for x in flips)
  # Matching constraints: input Rk reuses the register of output operand k.
  inputs.extend('[R%d] "%d"(v%d)' % (k, k, k) for k in range(narg))
  lines.append(", ".join(inputs) + ");")
  lines.append('}')
  return "\\\n".join(lines)
def aux_flip_init():
  """Return the ``AUX_FLIP_DEF`` C macro declaring ``flip0``..``flip7`` and ``aux``.

  ``flip<i>`` is an ``int256`` whose k-th ``simd_set_int256`` argument is
  ``(0x3ff | bit_k(i) << 11) << 52``. That value has bits 52-61 set, and
  bit 63 set iff bit k of *i* is set. Only masks 0-7 are generated because
  bitonic_vec complements any flip mask that has bit 3 set. ``aux`` has only
  bit 63 set in every argument. The macro ends with an inline
  ``wcsr %0, 0x80`` fed the value 0x6c (presumably a control-register
  write; confirm against the ISA manual).

  Constants are emitted as ``(long)(0x...UL << 52)``. Shifting a signed
  ``long`` into its sign bit (as the former ``0xbffL << 52`` / ``1L << 63``
  did) is undefined behaviour in C. The unsigned shift is well defined, and
  the explicit cast keeps the signed type the former ``L`` suffix implied.
  """
  lines = ["#define AUX_FLIP_DEF "]
  for i in range(8):
    lanes = tuple((i >> k & 1) << 11 | 0x3ff for k in range(4))
    lines.append(
      "int256 flip%x = simd_set_int256((long)(0x%xUL << 52), (long)(0x%xUL << 52), "
      "(long)(0x%xUL << 52), (long)(0x%xUL << 52));" % ((i,) + lanes))
  lines.append("int256 aux = simd_set_int256((long)(1UL << 63), (long)(1UL << 63), "
               "(long)(1UL << 63), (long)(1UL << 63));")
  lines.append(r'asm("wcsr %0, 0x80\n\t":: "r"(0x6c));')
  return "\\\n".join(lines)
# Emit every macro, separated by blank lines, followed directly by the
# AUX_FLIP_DEF definitions that the macros' asm operands read.
print("\n\n".join([
  gen_bitonic_macro("BITONIC_VEC_INIT_8", bitonic_vec_init_8()),
  gen_bitonic_macro("BITONIC_INC_8", bitonic_inc_8()),
  gen_bitonic_macro("BITONIC_DEC_8", bitonic_dec_8()),
  gen_bitonic_macro("BITONIC_INC_4V", minmax_vec_16(), 4),
  gen_bitonic_macro("BITONIC_DEC_4V", minmax_vec_16(reverse=True), 4),
]))
print(aux_flip_init())